import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, concatenate, Conv3DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Activation, MaxPool2D, Concatenate
import scipy
from skimage import transform
from skimage import io
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from skimage.transform import resize
import tifffile
def plot_all(image, n_cols=8):
    """Show every slice ``image[:, :, k]`` of a volume as a grid of grayscale panels.

    Args:
        image: array indexed as ``image[:, :, k]``; one panel is drawn for
            each index ``k`` along axis 2.
        n_cols: panels per grid row. The default of 8 reproduces the original
            16 x 8 grid for a 128-slice volume.

    Each panel is titled ``"Count: k"``. The number of rows is derived from
    the slice count, so volumes with fewer than 128 slices no longer raise
    IndexError and deeper volumes are no longer truncated.
    """
    n_slices = image.shape[2]
    n_rows = max(1, -(-n_slices // n_cols))  # ceiling division
    # Keep the original proportions: 15 x 30 inches for 16 rows.
    plt.figure(figsize=(15, 30 * n_rows / 16))
    for index in range(n_slices):
        plt.subplot(n_rows, n_cols, index + 1).axis("off")
        plt.title("Count: " + str(index))
        plt.imshow(image[:, :, index], cmap='gray')
CROP_RATE = 0.2


def cropScan2(scan, crop_rate=None):
    """Trim a fraction of each side along the first two axes of ``scan``.

    Args:
        scan: array with at least two dimensions. Any further axes (e.g. the
            slice axis of a 3D volume) are kept whole.
        crop_rate: fraction removed from each side of axes 0 and 1. ``None``
            (the default) uses the module-level ``CROP_RATE``, read at call
            time as before.

    Returns:
        The sliced view ``scan[r0:r1, c0:c1, ...]``. Bounds are truncated
        with ``int()``, so a 256-wide axis at 0.2 becomes 51:204 (153 wide).
    """
    if crop_rate is None:
        crop_rate = CROP_RATE
    n_rows, n_cols = scan.shape[0], scan.shape[1]
    return scan[int(crop_rate * n_rows):int((1 - crop_rate) * n_rows),
                int(crop_rate * n_cols):int((1 - crop_rate) * n_cols),
                ...]
''' Read Images/Masks '''
images = tifffile.imread('ivc_filter_images_84.tif')
masks = tifffile.imread("ivc_filter_masks_84.tif")
# Every later step (crop, resize, split) pairs images[k] with masks[k], so the
# two stacks must line up exactly.
if images.shape != masks.shape:
    raise ValueError(f"image/mask stacks differ in shape: {images.shape} vs {masks.shape}")
print(images.shape)
print(masks.shape)
# Recorded output:
# (84, 256, 256, 128) (84, 256, 256, 128)
''' Crop 20% on all sides'''
# cropScan2 trims CROP_RATE from each side of the first two axes only; the
# last axis is kept whole (256 -> 153, while 128 stays 128; see output below).
# NOTE(review): `/ 255` assumes 8-bit image intensities - confirm the TIFF dtype.
final_images = np.asarray([cropScan2(img / 255) for img in images])
final_masks = []
for mask in masks:
    peak = mask.max()
    # Scale each mask to [0, 1]. An all-background mask (max == 0) would
    # otherwise become all-NaN through 0 / 0.
    scaled = mask / peak if peak > 0 else mask.astype(float)
    final_masks.append(cropScan2(scaled))
final_masks = np.asarray(final_masks)
print(final_images.shape, final_masks.shape)
print(np.unique(final_masks))
# Recorded output:
# (84, 153, 153, 128) (84, 153, 153, 128) [0. 1.]
''' Resize to (128, 128, 128) '''
# The batch axis is taken from the data instead of hard-coding 84 scans.
final_images = resize(final_images, (len(final_images), 128, 128, 128))
# Masks are label maps. Linear interpolation plus anti-aliasing produced
# fractional labels, which to_categorical() and astype("int32") later truncate
# to 0, eroding every mask boundary. Nearest-neighbour resampling without
# anti-aliasing keeps them strictly 0/1.
final_masks = resize(final_masks, (len(final_masks), 128, 128, 128),
                     order=0, anti_aliasing=False, preserve_range=True)
print(final_images.shape, final_masks.shape)
print(np.unique(final_masks))
# Recorded output (from the earlier linear-interpolation mask resize; with
# order=0 the unique mask values are just [0. 1.]):
# (84, 128, 128, 128) (84, 128, 128, 128) [0.00000000e+00 1.37329102e-04 2.28881836e-04 ... 9.99649048e-01 9.99771118e-01 1.00000000e+00]
# Compare slice 54 of the first scan before and after crop + resize. The last
# axis keeps 128 slices through both steps, so index 54 is the same slice in
# each panel. ("Augmented" in the titles means cropped/resized; no augmentation
# is applied.)
fig = plt.figure(figsize = (12, 12))
panels = [
    ("Original Image", images[0], 'gray'),
    ("Augmented Image", final_images[0], 'gray'),
    ("Original Mask", masks[0], None),
    ("Augmented Mask", final_masks[0], None),
]
for position, (title, volume, cmap) in enumerate(panels, start=1):
    plt.subplot(2, 2, position).axis("off")
    plt.title(title)
    plt.imshow(volume[:, :, 54], cmap=cmap)
# Recorded output:
# <matplotlib.image.AxesImage at 0x7f49b2307e20>
plot_all(images[0])
# Check shapes & Splits into training/testing
# 80/20 split with a fixed seed. The evaluation section repeats this exact
# call, which reproduces the same partition.
x_train, x_test, y_train, y_test = train_test_split(final_images, final_masks, test_size=0.20, random_state=7)
for split in (x_train, y_train, x_test, y_test):
    print(split.shape)
# Recorded output:
# (67, 128, 128, 128) (67, 128, 128, 128) (17, 128, 128, 128) (17, 128, 128, 128)
''' Expand_dims and One Hot Encoding'''
# Add a trailing channel axis to match the model input (128, 128, 128, 1).
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)
# num_classes is fixed so both splits always get exactly 2 channels, even if
# one split happened to contain no foreground voxels.
y_train = to_categorical(y_train, num_classes=2)
y_test = to_categorical(y_test, num_classes=2)
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)
# Recorded output:
# (67, 128, 128, 128, 1) (67, 128, 128, 128, 2) (17, 128, 128, 128, 1) (17, 128, 128, 128, 2)
# Model hyper-parameters.
LR = 1e-4  # learning rate passed to Adam below
channels = 1  # channel count of the input volumes (not referenced below; Input shape is spelled out literally)
optim = Adam(learning_rate=LR)
def conv_block(input, num_filters):
    """Apply two rounds of Conv3D(3x3x3, 'same') -> BatchNormalization -> ReLU.

    The BatchNormalization layers are an addition relative to the original
    U-Net design.

    Args:
        input: Keras tensor to transform.
        num_filters: filter count for both convolutions.

    Returns:
        The activated output tensor of the second round.
    """
    features = input
    for _ in range(2):
        features = Conv3D(num_filters, 3, padding="same")(features)
        features = BatchNormalization()(features)  # not in the original network
        features = Activation("relu")(features)
    return features
def encoder_block(input, num_filters):
    """Encoder stage: conv_block followed by 2x2x2 max-pooling.

    Returns:
        ``(skip, pooled)``. ``skip`` is the pre-pooling feature map, kept for
        the matching decoder's concatenation. ``pooled`` is the downsampled
        tensor passed to the next stage.
    """
    skip = conv_block(input, num_filters)
    pooled = MaxPooling3D(pool_size=(2, 2, 2))(skip)
    return skip, pooled
def decoder_block(input, skip_features, num_filters):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, conv_block.

    Args:
        input: tensor from the coarser level.
        skip_features: encoder feature map of the matching resolution.
        num_filters: filter count for the upsampling and the conv_block.
    """
    upsampled = Conv3DTranspose(num_filters, kernel_size=(2, 2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)
def build_unet(input_shape, n_classes):
    """Assemble a 3-level 3D U-Net from encoder/decoder blocks.

    Encoder filters are 64 -> 128 -> 256, the bridge uses 512, and the decoder
    mirrors the encoder (256 -> 128 -> 64), consuming the skips deepest-first.
    The output is a 1x1x1 Conv3D with ``n_classes`` channels: sigmoid when
    ``n_classes == 1``, otherwise softmax. The chosen activation name is
    printed.

    Returns:
        A ``Model`` named "U-Net".
    """
    inputs = Input(input_shape)

    skips = []
    tensor = inputs
    for filters in (64, 128, 256):
        skip, tensor = encoder_block(tensor, filters)
        skips.append(skip)

    tensor = conv_block(tensor, 512)  # bridge

    for filters, skip in zip((256, 128, 64), reversed(skips)):
        tensor = decoder_block(tensor, skip, filters)

    activation = 'sigmoid' if n_classes == 1 else 'softmax'
    outputs = Conv3D(n_classes, 1, padding="same", activation=activation)(tensor)
    print(activation)
    return Model(inputs, outputs, name="U-Net")
# Metrics for a 2-channel one-hot target with a softmax output.
# Threshold metrics without a class_id (TP/FP/TN/FN, BinaryAccuracy, AUC) are
# computed over BOTH channels, so tp == tn and fp == fn; the training log shows
# this. Precision and Recall are restricted to channel 1 (mask value 1, i.e.
# the foreground) so they measure the segmentation itself instead of repeating
# accuracy. Their names are unchanged, so plotting code keyed on them still works.
METRICS = [
    tf.keras.metrics.TruePositives(name='tp'),
    tf.keras.metrics.FalsePositives(name='fp'),
    tf.keras.metrics.TrueNegatives(name='tn'),
    tf.keras.metrics.FalseNegatives(name='fn'),
    tf.keras.metrics.BinaryAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision', class_id=1),
    tf.keras.metrics.Recall(name='recall', class_id=1),
    tf.keras.metrics.AUC(name='auc'),
    tf.keras.metrics.AUC(name='prc', curve='PR'),  # precision-recall curve
]
# Two output channels (background / foreground) to match the one-hot targets.
model = build_unet((128, 128, 128, 1), n_classes=2)
model.compile(optimizer=optim, loss=tf.keras.losses.CategoricalCrossentropy(), metrics=METRICS)
# summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" line.
model.summary()
# Recorded output (print(activation) inside build_unet, then the model summary):
# softmax
# Model: "U-Net"
# __________________________________________________________________________________________________
# Layer (type) Output Shape Param # Connected to
# ==================================================================================================
# input_1 (InputLayer) [(None, 128, 128, 12 0
# __________________________________________________________________________________________________
# conv3d (Conv3D) (None, 128, 128, 128 1792 input_1[0][0]
# __________________________________________________________________________________________________
# batch_normalization (BatchNorma (None, 128, 128, 128 256 conv3d[0][0]
# __________________________________________________________________________________________________
# activation (Activation) (None, 128, 128, 128 0 batch_normalization[0][0]
# __________________________________________________________________________________________________
# conv3d_1 (Conv3D) (None, 128, 128, 128 110656 activation[0][0]
# __________________________________________________________________________________________________
# batch_normalization_1 (BatchNor (None, 128, 128, 128 256 conv3d_1[0][0]
# __________________________________________________________________________________________________
# activation_1 (Activation) (None, 128, 128, 128 0 batch_normalization_1[0][0]
# __________________________________________________________________________________________________
# max_pooling3d (MaxPooling3D) (None, 64, 64, 64, 6 0 activation_1[0][0]
# __________________________________________________________________________________________________
# conv3d_2 (Conv3D) (None, 64, 64, 64, 1 221312 max_pooling3d[0][0]
# __________________________________________________________________________________________________
# batch_normalization_2 (BatchNor (None, 64, 64, 64, 1 512 conv3d_2[0][0]
# __________________________________________________________________________________________________
# activation_2 (Activation) (None, 64, 64, 64, 1 0 batch_normalization_2[0][0]
# __________________________________________________________________________________________________
# conv3d_3 (Conv3D) (None, 64, 64, 64, 1 442496 activation_2[0][0]
# __________________________________________________________________________________________________
# batch_normalization_3 (BatchNor (None, 64, 64, 64, 1 512 conv3d_3[0][0]
# __________________________________________________________________________________________________
# activation_3 (Activation) (None, 64, 64, 64, 1 0 batch_normalization_3[0][0]
# __________________________________________________________________________________________________
# max_pooling3d_1 (MaxPooling3D) (None, 32, 32, 32, 1 0 activation_3[0][0]
# __________________________________________________________________________________________________
# conv3d_4 (Conv3D) (None, 32, 32, 32, 2 884992 max_pooling3d_1[0][0]
# __________________________________________________________________________________________________
# batch_normalization_4 (BatchNor (None, 32, 32, 32, 2 1024 conv3d_4[0][0]
# __________________________________________________________________________________________________
# activation_4 (Activation) (None, 32, 32, 32, 2 0 batch_normalization_4[0][0]
# __________________________________________________________________________________________________
# conv3d_5 (Conv3D) (None, 32, 32, 32, 2 1769728 activation_4[0][0]
# __________________________________________________________________________________________________
# batch_normalization_5 (BatchNor (None, 32, 32, 32, 2 1024 conv3d_5[0][0]
# __________________________________________________________________________________________________
# activation_5 (Activation) (None, 32, 32, 32, 2 0 batch_normalization_5[0][0]
# __________________________________________________________________________________________________
# max_pooling3d_2 (MaxPooling3D) (None, 16, 16, 16, 2 0 activation_5[0][0]
# __________________________________________________________________________________________________
# conv3d_6 (Conv3D) (None, 16, 16, 16, 5 3539456 max_pooling3d_2[0][0]
# __________________________________________________________________________________________________
# batch_normalization_6 (BatchNor (None, 16, 16, 16, 5 2048 conv3d_6[0][0]
# __________________________________________________________________________________________________
# activation_6 (Activation) (None, 16, 16, 16, 5 0 batch_normalization_6[0][0]
# __________________________________________________________________________________________________
# conv3d_7 (Conv3D) (None, 16, 16, 16, 5 7078400 activation_6[0][0]
# __________________________________________________________________________________________________
# batch_normalization_7 (BatchNor (None, 16, 16, 16, 5 2048 conv3d_7[0][0]
# __________________________________________________________________________________________________
# activation_7 (Activation) (None, 16, 16, 16, 5 0 batch_normalization_7[0][0]
# __________________________________________________________________________________________________
# conv3d_transpose (Conv3DTranspo (None, 32, 32, 32, 2 1048832 activation_7[0][0]
# __________________________________________________________________________________________________
# concatenate (Concatenate) (None, 32, 32, 32, 5 0 conv3d_transpose[0][0]
# activation_5[0][0]
# __________________________________________________________________________________________________
# conv3d_8 (Conv3D) (None, 32, 32, 32, 2 3539200 concatenate[0][0]
# __________________________________________________________________________________________________
# batch_normalization_8 (BatchNor (None, 32, 32, 32, 2 1024 conv3d_8[0][0]
# __________________________________________________________________________________________________
# activation_8 (Activation) (None, 32, 32, 32, 2 0 batch_normalization_8[0][0]
# __________________________________________________________________________________________________
# conv3d_9 (Conv3D) (None, 32, 32, 32, 2 1769728 activation_8[0][0]
# __________________________________________________________________________________________________
# batch_normalization_9 (BatchNor (None, 32, 32, 32, 2 1024 conv3d_9[0][0]
# __________________________________________________________________________________________________
# activation_9 (Activation) (None, 32, 32, 32, 2 0 batch_normalization_9[0][0]
# __________________________________________________________________________________________________
# conv3d_transpose_1 (Conv3DTrans (None, 64, 64, 64, 1 262272 activation_9[0][0]
# __________________________________________________________________________________________________
# concatenate_1 (Concatenate) (None, 64, 64, 64, 2 0 conv3d_transpose_1[0][0]
# activation_3[0][0]
# __________________________________________________________________________________________________
# conv3d_10 (Conv3D) (None, 64, 64, 64, 1 884864 concatenate_1[0][0]
# __________________________________________________________________________________________________
# batch_normalization_10 (BatchNo (None, 64, 64, 64, 1 512 conv3d_10[0][0]
# __________________________________________________________________________________________________
# activation_10 (Activation) (None, 64, 64, 64, 1 0 batch_normalization_10[0][0]
# __________________________________________________________________________________________________
# conv3d_11 (Conv3D) (None, 64, 64, 64, 1 442496 activation_10[0][0]
# __________________________________________________________________________________________________
# batch_normalization_11 (BatchNo (None, 64, 64, 64, 1 512 conv3d_11[0][0]
# __________________________________________________________________________________________________
# activation_11 (Activation) (None, 64, 64, 64, 1 0 batch_normalization_11[0][0]
# __________________________________________________________________________________________________
# conv3d_transpose_2 (Conv3DTrans (None, 128, 128, 128 65600 activation_11[0][0]
# __________________________________________________________________________________________________
# concatenate_2 (Concatenate) (None, 128, 128, 128 0 conv3d_transpose_2[0][0]
# activation_1[0][0]
# __________________________________________________________________________________________________
# conv3d_12 (Conv3D) (None, 128, 128, 128 221248 concatenate_2[0][0]
# __________________________________________________________________________________________________
# batch_normalization_12 (BatchNo (None, 128, 128, 128 256 conv3d_12[0][0]
# __________________________________________________________________________________________________
# activation_12 (Activation) (None, 128, 128, 128 0 batch_normalization_12[0][0]
# __________________________________________________________________________________________________
# conv3d_13 (Conv3D) (None, 128, 128, 128 110656 activation_12[0][0]
# __________________________________________________________________________________________________
# batch_normalization_13 (BatchNo (None, 128, 128, 128 256 conv3d_13[0][0]
# __________________________________________________________________________________________________
# activation_13 (Activation) (None, 128, 128, 128 0 batch_normalization_13[0][0]
# __________________________________________________________________________________________________
# conv3d_14 (Conv3D) (None, 128, 128, 128 130 activation_13[0][0]
# ==================================================================================================
# Total params: 22,405,122
# Trainable params: 22,399,490
# Non-trainable params: 5,632
# __________________________________________________________________________________________________
# None
'''Checks'''
print("Input shape", model.input_shape)
print("Output shape", model.output_shape)
print("-------------------")
# Fail fast if the network and the prepared arrays disagree (batch axis excluded).
if tuple(model.input_shape[1:]) != x_train.shape[1:]:
    raise ValueError(f"model input {model.input_shape} does not match x_train {x_train.shape}")
if tuple(model.output_shape[1:]) != y_train.shape[1:]:
    raise ValueError(f"model output {model.output_shape} does not match y_train {y_train.shape}")
# Recorded output:
# Input shape (None, 128, 128, 128, 1)
# Output shape (None, 128, 128, 128, 2)
# -------------------
# Train for up to 100 epochs, one 128^3 volume per step (batch_size=1,
# presumably to fit GPU memory). The held-out test split doubles as the
# validation data monitored each epoch.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=1,
    epochs=100,
    validation_data=(x_test, y_test),
    shuffle=True,
    verbose=1,
)
# Recorded training log (the run stops partway through epoch 22):
# Epoch 1/100 67/67 [==============================] - 128s 2s/step - loss: 0.2433 - tp: 70665370.2647 - fp: 1655533.2353 - tn: 70665370.3088 - fn: 1655534.1618 - accuracy: 0.9621 - precision: 0.9621 - recall: 0.9621 - auc: 0.9815 - prc: 0.9721 - val_loss: 0.4283 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9968 - val_prc: 0.9950 Epoch 2/100 67/67 [==============================] - 108s 2s/step - loss: 0.0531 - tp: 72154285.7500 - fp: 166623.1471 - tn: 72154285.7500 - fn: 166623.1471 - accuracy: 0.9976 - precision: 0.9976 - recall: 0.9976 - auc: 0.9987 - prc: 0.9983 - val_loss: 0.0891 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9967 - val_prc: 0.9951 Epoch 3/100 67/67 [==============================] - 108s 2s/step - loss: 0.0365 - tp: 72162598.8824 - fp: 158310.0735 - tn: 72162598.8824 - fn: 158310.0735 - accuracy: 0.9978 - precision: 0.9978 - recall: 0.9978 - auc: 0.9995 - prc: 0.9994 - val_loss: 0.0297 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9973 - val_prc: 0.9964 Epoch 4/100 67/67 [==============================] - 108s 2s/step - loss: 0.0283 - tp: 72197312.9559 - fp: 123597.2941 - tn: 72197312.9559 - fn: 123597.2941 - accuracy: 0.9982 - precision: 0.9982 - recall: 0.9982 - auc: 0.9999 - prc: 0.9999 - val_loss: 0.0461 - val_tp: 35557960.0000 - val_fp: 93625.0000 - val_tn: 35557960.0000 - val_fn: 93625.0000 - val_accuracy: 0.9974 - val_precision: 0.9974 - val_recall: 0.9974 - val_auc: 0.9979 - val_prc: 0.9976 Epoch 5/100 67/67 [==============================] - 108s 2s/step - loss: 0.0253 - tp: 72234025.4265 - fp: 86883.6029 - tn: 72234025.4265 - fn: 86883.6029 - accuracy: 0.9988 - 
# precision: 0.9988 - recall: 0.9988 - auc: 0.9999 - prc: 0.9999 - val_loss: 0.0319 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9982 - val_prc: 0.9980 Epoch 6/100 67/67 [==============================] - 108s 2s/step - loss: 0.0199 - tp: 72242459.7647 - fp: 78438.6029 - tn: 72242459.7647 - fn: 78438.6029 - accuracy: 0.9988 - precision: 0.9988 - recall: 0.9988 - auc: 1.0000 - prc: 0.9999 - val_loss: 0.0263 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9981 - val_prc: 0.9979 Epoch 7/100 67/67 [==============================] - 108s 2s/step - loss: 0.0172 - tp: 72243741.8676 - fp: 77156.5294 - tn: 72243741.8676 - fn: 77156.5294 - accuracy: 0.9990 - precision: 0.9990 - recall: 0.9990 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0265 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9986 - val_prc: 0.9984 Epoch 8/100 67/67 [==============================] - 108s 2s/step - loss: 0.0145 - tp: 72254469.0882 - fp: 66439.4265 - tn: 72254469.0882 - fn: 66439.4265 - accuracy: 0.9991 - precision: 0.9991 - recall: 0.9991 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0222 - val_tp: 35563260.0000 - val_fp: 88324.0000 - val_tn: 35563260.0000 - val_fn: 88324.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9991 - val_prc: 0.9990 Epoch 9/100 67/67 [==============================] - 108s 2s/step - loss: 0.0128 - tp: 72256076.2059 - fp: 64826.4559 - tn: 72256076.2059 - fn: 64826.4559 - accuracy: 0.9991 - precision: 0.9991 - recall: 0.9991 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0176 - val_tp: 35574656.0000 - val_fp: 76930.0000 - val_tn: 35574656.0000 - val_fn: 76930.0000 - val_accuracy: 
# 0.9978 - val_precision: 0.9978 - val_recall: 0.9978 - val_auc: 0.9996 - val_prc: 0.9996 Epoch 10/100 67/67 [==============================] - 108s 2s/step - loss: 0.0120 - tp: 72256378.6618 - fp: 64520.3382 - tn: 72256378.6618 - fn: 64520.3382 - accuracy: 0.9990 - precision: 0.9990 - recall: 0.9990 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0162 - val_tp: 35578904.0000 - val_fp: 72679.0000 - val_tn: 35578904.0000 - val_fn: 72679.0000 - val_accuracy: 0.9980 - val_precision: 0.9980 - val_recall: 0.9980 - val_auc: 0.9996 - val_prc: 0.9995 Epoch 11/100 67/67 [==============================] - 109s 2s/step - loss: 0.0105 - tp: 72257807.3235 - fp: 63087.6912 - tn: 72257807.3235 - fn: 63087.6912 - accuracy: 0.9992 - precision: 0.9992 - recall: 0.9992 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0143 - val_tp: 35583928.0000 - val_fp: 67649.0000 - val_tn: 35583928.0000 - val_fn: 67649.0000 - val_accuracy: 0.9981 - val_precision: 0.9981 - val_recall: 0.9981 - val_auc: 0.9997 - val_prc: 0.9996 Epoch 12/100 67/67 [==============================] - 109s 2s/step - loss: 0.0096 - tp: 72262295.7941 - fp: 58608.5294 - tn: 72262295.7941 - fn: 58608.5294 - accuracy: 0.9992 - precision: 0.9992 - recall: 0.9992 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0119 - val_tp: 35597504.0000 - val_fp: 54082.0000 - val_tn: 35597504.0000 - val_fn: 54082.0000 - val_accuracy: 0.9985 - val_precision: 0.9985 - val_recall: 0.9985 - val_auc: 0.9999 - val_prc: 0.9999 Epoch 13/100 67/67 [==============================] - 109s 2s/step - loss: 0.0091 - tp: 72258746.1029 - fp: 62154.3088 - tn: 72258746.1029 - fn: 62154.3088 - accuracy: 0.9991 - precision: 0.9991 - recall: 0.9991 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0096 - val_tp: 35612580.0000 - val_fp: 39006.0000 - val_tn: 35612580.0000 - val_fn: 39006.0000 - val_accuracy: 0.9989 - val_precision: 0.9989 - val_recall: 0.9989 - val_auc: 1.0000 - val_prc: 0.9999 Epoch 14/100 67/67 [==============================] - 109s 2s/step - loss: 0.0080 - tp: 
# 72267470.3971 - fp: 53431.3529 - tn: 72267470.3971 - fn: 53431.3529 - accuracy: 0.9993 - precision: 0.9993 - recall: 0.9993 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0104 - val_tp: 35600968.0000 - val_fp: 50617.0000 - val_tn: 35600968.0000 - val_fn: 50617.0000 - val_accuracy: 0.9986 - val_precision: 0.9986 - val_recall: 0.9986 - val_auc: 0.9998 - val_prc: 0.9997 Epoch 15/100 67/67 [==============================] - 109s 2s/step - loss: 0.0079 - tp: 72261981.5735 - fp: 58922.7206 - tn: 72261981.5735 - fn: 58922.7206 - accuracy: 0.9991 - precision: 0.9991 - recall: 0.9991 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0087 - val_tp: 35606428.0000 - val_fp: 45154.0000 - val_tn: 35606428.0000 - val_fn: 45154.0000 - val_accuracy: 0.9987 - val_precision: 0.9987 - val_recall: 0.9987 - val_auc: 0.9999 - val_prc: 0.9999 Epoch 16/100 67/67 [==============================] - 109s 2s/step - loss: 0.0069 - tp: 72267799.3529 - fp: 53111.0735 - tn: 72267799.3529 - fn: 53111.0735 - accuracy: 0.9993 - precision: 0.9993 - recall: 0.9993 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0097 - val_tp: 35605944.0000 - val_fp: 45643.0000 - val_tn: 35605944.0000 - val_fn: 45643.0000 - val_accuracy: 0.9987 - val_precision: 0.9987 - val_recall: 0.9987 - val_auc: 0.9998 - val_prc: 0.9997 Epoch 17/100 67/67 [==============================] - 109s 2s/step - loss: 0.0067 - tp: 72263938.5294 - fp: 56963.2500 - tn: 72263938.5294 - fn: 56963.2500 - accuracy: 0.9992 - precision: 0.9992 - recall: 0.9992 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0081 - val_tp: 35607352.0000 - val_fp: 44230.0000 - val_tn: 35607352.0000 - val_fn: 44230.0000 - val_accuracy: 0.9988 - val_precision: 0.9988 - val_recall: 0.9988 - val_auc: 1.0000 - val_prc: 0.9999 Epoch 18/100 67/67 [==============================] - 108s 2s/step - loss: 0.0059 - tp: 72272285.5147 - fp: 48616.2500 - tn: 72272285.5147 - fn: 48616.2500 - accuracy: 0.9994 - precision: 0.9994 - recall: 0.9994 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0095 - val_tp: 
# 35597456.0000 - val_fp: 54130.0000 - val_tn: 35597456.0000 - val_fn: 54130.0000 - val_accuracy: 0.9985 - val_precision: 0.9985 - val_recall: 0.9985 - val_auc: 0.9997 - val_prc: 0.9997 Epoch 19/100 67/67 [==============================] - 108s 2s/step - loss: 0.0054 - tp: 72274575.7647 - fp: 46329.9118 - tn: 72274575.7647 - fn: 46329.9118 - accuracy: 0.9994 - precision: 0.9994 - recall: 0.9994 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0076 - val_tp: 35607356.0000 - val_fp: 44227.0000 - val_tn: 35607356.0000 - val_fn: 44227.0000 - val_accuracy: 0.9988 - val_precision: 0.9988 - val_recall: 0.9988 - val_auc: 0.9999 - val_prc: 0.9998 Epoch 20/100 67/67 [==============================] - 108s 2s/step - loss: 0.0054 - tp: 72271531.6618 - fp: 49375.7353 - tn: 72271531.6618 - fn: 49375.7353 - accuracy: 0.9993 - precision: 0.9993 - recall: 0.9993 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0071 - val_tp: 35599040.0000 - val_fp: 52543.0000 - val_tn: 35599040.0000 - val_fn: 52543.0000 - val_accuracy: 0.9985 - val_precision: 0.9985 - val_recall: 0.9985 - val_auc: 1.0000 - val_prc: 1.0000 Epoch 21/100 67/67 [==============================] - 109s 2s/step - loss: 0.0049 - tp: 72272039.5294 - fp: 48863.3235 - tn: 72272039.5294 - fn: 48863.3235 - accuracy: 0.9993 - precision: 0.9993 - recall: 0.9993 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0078 - val_tp: 35603876.0000 - val_fp: 47705.0000 - val_tn: 35603876.0000 - val_fn: 47705.0000 - val_accuracy: 0.9987 - val_precision: 0.9987 - val_recall: 0.9987 - val_auc: 0.9997 - val_prc: 0.9997 Epoch 22/100 13/67 [====>.........................] - ETA: 1:20 - loss: 0.0046 - tp: 14671562.1538 - fp: 8501.6923 - tn: 14671562.1538 - fn: 8501.6923 - accuracy: 0.9994 - precision: 0.9994 - recall: 0.9994 - auc: 1.0000 - prc: 1.0000
# Persist the trained network for later reuse.
# NOTE(review): the evaluation section loads '3D_UNet_no_patch_normalized.h5',
# not this file - confirm which checkpoint is meant to be evaluated.
model.save(filepath='3D_UNet_no_patch_1.h5')
def plot_metrics(history, metrics=('loss', 'prc', 'precision', 'recall')):
    """Plot training (solid) vs. validation (dashed) curves, two panels per row.

    Args:
        history: object returned by ``model.fit``. Its ``epoch`` list and its
            ``history`` dict must hold both ``<metric>`` and ``val_<metric>``.
        metrics: metric keys to plot. The default reproduces the original
            four-panel 2 x 2 figure. Making this a parameter also makes the
            ``'auc'`` axis-limit branch reachable.

    Y-limits: loss starts at 0, auc spans [0.8, 1.1], and everything else
    spans [0, 1.1].
    """
    n_rows = max(1, (len(metrics) + 1) // 2)
    plt.figure(figsize=(8, 4 * n_rows))  # 8 x 8 for the default four metrics
    for n, metric in enumerate(metrics):
        name = metric.replace("_", " ").capitalize()
        plt.subplot(n_rows, 2, n + 1)
        plt.plot(history.epoch, history.history[metric], color='blue', label='Train')
        plt.plot(history.epoch, history.history['val_' + metric],
                 color='blue', linestyle="--", label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        if metric == 'loss':
            plt.ylim([0, plt.ylim()[1]])
        elif metric == 'auc':
            plt.ylim([0.8, 1.1])
        else:
            plt.ylim([0, 1.1])
        plt.legend()
from tensorflow.keras.models import load_model

plot_metrics(history)

# Reload a pretrained model for testing and predictions. It is not compiled
# after loading; only predict() is used below.
# NOTE(review): this is not the file saved after training above
# ('3D_UNet_no_patch_1.h5'). If you load a different model, do not forget to
# preprocess the inputs accordingly.
my_model = load_model(filepath='3D_UNet_no_patch_normalized.h5', compile=False)
''' Get the train and test sets after crop & resize '''
# Same inputs, test_size and random_state as the training split, so this
# reproduces the same partition, before the channel axis and one-hot encoding.
x_train_orig, x_test_orig, y_train_orig, y_test_orig = train_test_split(final_images, final_masks, test_size=0.20, random_state=7)
print(x_train_orig.shape, y_train_orig.shape)
print(x_test_orig.shape, y_test_orig.shape)
# Recorded output:
# (67, 128, 128, 128) (67, 128, 128, 128) (17, 128, 128, 128) (17, 128, 128, 128)
''' PREDICT on Training Set '''
img = np.expand_dims(x_train_orig, axis=-1)  # add the channel axis
print(img.shape)
ground_truth = y_train_orig
print(ground_truth.shape)
# Predict one volume at a time: predicting the whole array at once raised a
# resource-exhausted error.
per_volume = [my_model.predict(np.expand_dims(i, axis=0)) for i in img]
# Each prediction keeps its batch axis of size 1. Concatenating on that axis
# gives (N, ..., classes) for any N, whereas np.squeeze would also drop the
# batch axis when N == 1.
end = np.concatenate(per_volume, axis=0)
print(end.shape)
train_prediction = np.argmax(end, axis=-1)  # per-voxel class id
print(train_prediction.shape)
# Recorded output:
# (67, 128, 128, 128, 1) (67, 128, 128, 128) (67, 128, 128, 128, 2) (67, 128, 128, 128)
''' MEAN IOU on Training Set '''
from tensorflow.keras.metrics import MeanIoU
n_classes = 2
IOU_keras = MeanIoU(num_classes=n_classes)
# MeanIoU expects integer class ids; the masks are stored as floats.
gt1 = ground_truth.astype("int32")
IOU_keras.update_state(gt1, train_prediction)
print("Training Set:", IOU_keras.result().numpy())
# Recorded output:
# Training Set: 0.90223163
''' Training Set Per Pixel Basis '''
from sklearn.metrics import confusion_matrix
import seaborn as sns
# labels=[0, 1] pins a 2x2 matrix (rows: true label, columns: predicted label)
# so it always matches the two tick labels set below, even if a class is absent.
y_train_matrix = confusion_matrix(np.asarray(gt1).ravel(), np.asarray(train_prediction).ravel(), labels=[0, 1])
plt.figure()  # draw on a fresh figure rather than whatever is current
ax = plt.subplot()
# annot=True writes the count in each cell; fmt='d' prints plain integers
# instead of seaborn's default '.2g' (scientific notation for large counts).
sns.heatmap(y_train_matrix, annot=True, fmt='d', ax=ax)
# labels, title and ticks
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Train Set Per Pixel Basis')
ax.xaxis.set_ticklabels(['0', '1'])
ax.yaxis.set_ticklabels(['0', '1'])
''' PREDICT on Testing Set '''
# Image from Testing Set
img = np.expand_dims(x_test_orig, axis=-1)  # add the channel axis
# One volume per predict() call, since whole-array prediction exhausted resources.
per_volume = [my_model.predict(np.expand_dims(i, axis=0)) for i in img]
# Concatenate on the batch axis; unlike np.squeeze this also works for N == 1.
end = np.concatenate(per_volume, axis=0)
print(end.shape)
test_prediction = np.argmax(end, axis=-1)  # per-voxel class id
print(test_prediction.shape)
ground_truth = y_test_orig.astype("int32")
print(ground_truth.shape)
print(np.unique(ground_truth))
# Recorded output:
# (17, 128, 128, 128, 2) (17, 128, 128, 128) (17, 128, 128, 128) [0 1]
''' MEAN IOU on Testing Set'''
from tensorflow.keras.metrics import MeanIoU
n_classes = 2
# A fresh metric object, so training-set counts are not accumulated into
# the test score.
IOU_keras = MeanIoU(num_classes=n_classes)
gt2 = ground_truth  # already int32 from the prediction cell
IOU_keras.update_state(gt2, test_prediction)
print("Testing Set:", IOU_keras.result().numpy())
# Recorded output:
# Testing Set: 0.80078983
''' Test Set Per Pixel Based '''
# labels=[0, 1] pins a 2x2 matrix (rows: true label, columns: predicted label)
# so it always matches the two tick labels set below, even if a class is absent.
y_test_matrix = confusion_matrix(np.asarray(gt2).ravel(), np.asarray(test_prediction).ravel(), labels=[0, 1])
plt.figure()  # draw on a fresh figure rather than whatever is current
ax = plt.subplot()
# annot=True writes the count in each cell; fmt='d' prints plain integers
# instead of seaborn's default '.2g' (scientific notation for large counts).
sns.heatmap(y_test_matrix, annot=True, fmt='d', ax=ax)
# labels, title and ticks
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Test Set Per Pixel Basis')
ax.xaxis.set_ticklabels(['0', '1'])
ax.yaxis.set_ticklabels(['0', '1'])
def plot_three(image, mask, predicted):
    """Show image, ground-truth mask and predicted mask side by side, one row per slice.

    Args:
        image: volume indexed as ``image[:, :, k]``; shown in grayscale.
        mask: ground-truth volume indexed the same way.
        predicted: predicted label volume indexed the same way.

    The row count is taken from ``image.shape[2]`` instead of a hard-coded
    128. ``mask`` and ``predicted`` are assumed to have at least that many
    slices (TODO confirm for callers). Titles are "Image Slice: k",
    "Mask Slice: k" and "Predicted Mask Slice: k".
    """
    n_slices = image.shape[2]
    # Keep the original proportions: 12 x 500 inches for 128 rows.
    plt.figure(figsize=(12, 500 * n_slices / 128))
    columns = (
        ("Image Slice: ", image, {'cmap': 'gray'}),
        ("Mask Slice: ", mask, {}),
        ("Predicted Mask Slice: ", predicted, {}),
    )
    for index in range(n_slices):
        for col, (label, volume, style) in enumerate(columns):
            plt.subplot(n_slices, 3, 3 * index + col + 1).axis("off")
            plt.title(label + str(index))
            plt.imshow(volume[:, :, index], **style)
# Slice-by-slice comparison for test-split volume #3: image, ground truth, prediction.
plot_three(image=x_test_orig[3], mask=y_test_orig[3], predicted=test_prediction[3])